{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from hfnet.evaluation.image_retrieval import compute_recall, is_gt_match_2D\n",
    "from hfnet.datasets.nclt import Nclt\n",
    "from hfnet.settings import DATA_PATH, EXPER_PATH\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(seq, experiment):\n",
    "    \"\"\"Load the poses of one sequence and the exported global descriptor of each frame.\n",
    "\n",
    "    Args:\n",
    "        seq: sequence identifier, passed to `Nclt.get_pose_file` and used as the\n",
    "            export sub-directory name.\n",
    "        experiment: export name; descriptors are read from\n",
    "            `EXPER_PATH/exports/<experiment>/<seq>/<time>.npz`.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(im_poses, descriptors)`: the object returned by\n",
    "        `Nclt.get_pose_file` and an array built from the `global_descriptor`\n",
    "        entries, in the same order as `im_poses['time']`.\n",
    "    \"\"\"\n",
    "    def load_descriptor(stamp):\n",
    "        npz_path = Path(EXPER_PATH, 'exports', experiment, seq, f'{stamp}.npz')\n",
    "        # Copy the array before leaving the `with` block that closes the archive.\n",
    "        with np.load(npz_path) as archive:\n",
    "            return archive['global_descriptor'].copy()\n",
    "\n",
    "    im_poses = Nclt.get_pose_file(seq)\n",
    "    return im_poses, np.array([load_descriptor(t) for t in im_poses['time']])\n",
    "\n",
    "def nclt_recall(ref_seq, query_seqs, experiment, distance_thresh=10, angle_thresh=np.pi/2, *arg, **kwarg):\n",
    "    \"\"\"Compute retrieval recall of several query sequences against one reference sequence.\n",
    "\n",
    "    Args:\n",
    "        ref_seq: identifier of the reference sequence.\n",
    "        query_seqs: iterable of query sequence identifiers; their poses and\n",
    "            descriptors are concatenated along axis 0.\n",
    "        experiment: export name forwarded to `get_data`.\n",
    "        distance_thresh: forwarded to `is_gt_match_2D`.\n",
    "        angle_thresh: forwarded to `is_gt_match_2D` (presumably radians, given\n",
    "            the `np.pi/2` default -- confirm against that function).\n",
    "        *arg, **kwarg: forwarded unchanged to `compute_recall`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `compute_recall` returns for the reference descriptors, the\n",
    "        concatenated query descriptors and the ground-truth matches.\n",
    "    \"\"\"\n",
    "    ref_poses, ref_descriptors = get_data(ref_seq, experiment)\n",
    "\n",
    "    # One (poses, descriptors) pair per query sequence, in the given order.\n",
    "    per_sequence = [get_data(s, experiment) for s in query_seqs]\n",
    "    query_poses = np.concatenate([poses for poses, _ in per_sequence], axis=0)\n",
    "    query_descriptors = np.concatenate(\n",
    "        [descriptors for _, descriptors in per_sequence], axis=0)\n",
    "\n",
    "    gt_matches = is_gt_match_2D(query_poses, ref_poses, distance_thresh, angle_thresh)\n",
    "    return compute_recall(ref_descriptors, query_descriptors, gt_matches, *arg, **kwarg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exports to evaluate; each entry is a sub-directory of EXPER_PATH/exports.\n",
    "experiments = [\n",
    "    'netvlad/nclt',\n",
    "    # experiments\n",
    "]\n",
    "ref_seq = '2012-01-08'\n",
    "query_seqs = ['2013-02-23', '2012-08-20']\n",
    "\n",
    "plt.figure(dpi=150)\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "# Index the palette modulo its length: zip(experiments, colors) would silently\n",
    "# skip every experiment beyond len(colors).\n",
    "for i, e in enumerate(experiments):\n",
    "    c = colors[i % len(colors)]\n",
    "    m = nclt_recall(ref_seq, query_seqs, e, max_num_nn=30, pca_dim=0)\n",
    "    plt.plot(1+np.arange(len(m)), 100*m, label=e, color=c, linewidth=1);\n",
    "    # m[9] is the 10th entry of the curve, i.e. the value plotted at x=10.\n",
    "    print(f'{e:<70} Recall@10: {m[9]:.3f}')\n",
    "\n",
    "# Optional variant: same evaluation with pca_dim=512, drawn as a dashed line.\n",
    "#    m = nclt_recall(ref_seq, query_seqs, e, max_num_nn=30, pca_dim=512)\n",
    "#    plt.plot(1+np.arange(len(m)), 100*m, label=e+'_proj512', \n",
    "#             color=c, linewidth=1.3, linestyle='--');\n",
    "#    print(f'{e+\"_proj512\":<70} Recall@10: {m[9]:.3f}')\n",
    "\n",
    "plt.xticks([1]+np.arange(10, 31, step=10).tolist()); plt.grid(color=[0.85]*3);\n",
    "plt.legend(loc=9, bbox_to_anchor=(0.5, -0.2));\n",
    "plt.xlabel('Number of queried neighbors')\n",
    "plt.ylabel('Recall@N (%)');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
